import os
import random
import numpy as np
import torch
import torch.nn as nn


def set_seeds(seed=42):
    """Seed every random number generator used in this module.

    Seeds Python's ``random``, NumPy and PyTorch (CPU generator plus the
    current and all CUDA devices), exports ``PYTHONHASHSEED`` and puts cuDNN
    into deterministic, non-benchmarking mode.

    Args:
        seed: Seed value handed to every generator. Defaults to 42.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All of these accept the seed as their single positional argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class SoftDiceLoss(nn.Module):
    """Soft Dice loss: ``1 - mean(dice)``.

    The Dice coefficient is computed from soft true-positive, false-positive
    and false-negative sums over ``dims``; the resulting per-slice
    coefficients are averaged and subtracted from one.

    Args:
        smooth: Constant added to numerator and denominator so that an empty
            prediction/target pair does not divide by zero.
        dims: Dimensions reduced when summing tp/fp/fn. The default
            ``(-2, -1)`` reduces the last two dimensions.
    """

    def __init__(self, smooth=1., dims=(-2, -1)):
        super().__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        """Return the scalar Dice loss between prediction ``x`` and target ``y``.

        Args:
            x: Predictions. Expected to be values in [0, 1]; bool or integer
                tensors (e.g. thresholded predictions) are accepted and cast
                to float.
            y: Targets, broadcastable against ``x``. Non-floating tensors
                are cast to the dtype of ``x``.

        Returns:
            A zero-dimensional tensor ``1 - mean(dice)``.
        """
        # ``1 - t`` is not defined for bool tensors in PyTorch, so hard
        # (thresholded) inputs must be promoted before the arithmetic below.
        if not x.is_floating_point():
            x = x.float()
        if not y.is_floating_point():
            y = y.to(x.dtype)

        tp = (x * y).sum(self.dims)
        fp = (x * (1 - y)).sum(self.dims)
        fn = ((1 - x) * y).sum(self.dims)

        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        return 1 - dc.mean()
    
# Module-level loss instances used by loss_fn.
bce_fn = nn.BCEWithLogitsLoss()  # binary cross-entropy that takes raw logits
dice_fn = SoftDiceLoss()  # soft Dice loss with the class defaults

def loss_fn(y_pred, y_true, ratio=0.8, hard=False) -> torch.Tensor:
    """Weighted sum of BCE-with-logits and Dice loss.

    Args:
        y_pred: Raw model outputs (logits); the sigmoid is applied here for
            the Dice term and inside ``bce_fn`` for the BCE term.
        y_true: Float target tensor with the same shape as ``y_pred``.
        ratio: Weight of the BCE term; the Dice term gets ``1 - ratio``.
        hard: If True, the Dice term is computed on predictions thresholded
            at 0.5 instead of on the probabilities. Thresholding is not
            differentiable, so in that mode only the BCE term carries
            gradient.

    Returns:
        ``ratio * bce + (1 - ratio) * dice`` as a scalar tensor.
    """
    bce = bce_fn(y_pred, y_true)
    probs = y_pred.sigmoid()
    if hard:
        # Threshold first, then cast: the comparison yields a bool tensor,
        # and the Dice loss computes ``1 - x``, which bool tensors reject.
        probs = (probs > 0.5).float()
    dice = dice_fn(probs, y_true)
    return ratio * bce + (1 - ratio) * dice


@torch.no_grad()
def validation(model, loader, loss_fn):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    The model is put into eval mode and is not switched back afterwards.
    Each batch is moved to the device of the model's first parameter, and
    targets are cast to float before being passed to ``loss_fn``.

    Args:
        model: Module to evaluate; must have at least one parameter.
        loader: Iterable yielding ``(image, target)`` tensor pairs.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar
            tensor.

    Returns:
        The unweighted mean of the per-batch loss values.
    """
    model.eval()
    model_device = next(model.parameters()).device

    batch_losses = []
    for inputs, labels in loader:
        inputs = inputs.to(model_device)
        labels = labels.float().to(model_device)
        batch_loss = loss_fn(model(inputs), labels)
        batch_losses.append(batch_loss.item())

    return np.array(batch_losses).mean()

def np_dice_score(probability, mask):
    """Compute the Dice score between a prediction and a ground-truth mask.

    Both inputs are flattened and binarised with a strict ``> 0.5``
    threshold. A constant 0.001 in the denominator avoids division by zero,
    so two empty masks score 0.

    Args:
        probability: Array of predicted probabilities.
        mask: Array of ground-truth values with the same number of elements.

    Returns:
        ``2 * |P & T| / (|P| + |T| + 0.001)``.
    """
    pred = probability.reshape(-1) > 0.5
    truth = mask.reshape(-1) > 0.5

    intersection = (pred & truth).sum()
    total = pred.sum() + truth.sum()

    return 2 * intersection / (total + 0.001)

def validation_acc(model, val_loader, device=None):
    """Compute the Dice score of ``model`` over a whole validation loader.

    Predictions (sigmoid of the model output) and targets from every batch
    are gathered on the CPU, concatenated, and scored once with
    ``np_dice_score``.

    Args:
        model: Module to evaluate; it is put into eval mode and not switched
            back afterwards.
        val_loader: Iterable yielding ``(image, target)`` tensor pairs.
        device: Device the batches are moved to. When None, the device of
            the model's first parameter is used, so inputs always land where
            the model lives; a model without parameters falls back to
            ``'cuda:0'`` if CUDA is available, else ``'cpu'``.

    Returns:
        The Dice score computed over all batches together.

    Raises:
        ValueError: If ``val_loader`` yields no batches (raised by
            ``np.concatenate`` on an empty list).
    """
    if device is None:
        # Follow the model rather than guessing: picking 'cuda:0' merely
        # because CUDA exists breaks for models kept on the CPU or on
        # another GPU.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    val_probability, val_mask = [], []
    model.eval()
    with torch.no_grad():
        for image, target in val_loader:
            image, target = image.to(device), target.float().to(device)
            output = model(image)

            val_probability.append(output.sigmoid().detach().cpu().numpy())
            val_mask.append(target.detach().cpu().numpy())

    val_probability = np.concatenate(val_probability)
    val_mask = np.concatenate(val_mask)

    return np_dice_score(val_probability, val_mask)

